Llama3 hybrid implementation using submeshes #18777

Draft · ipotkonjak-tt wants to merge 17 commits into base: main
Conversation

@ipotkonjak-tt (Contributor) commented on Mar 7, 2025

Problem description

The Llama3 code base lacks support for data parallelism and hybrid (data + tensor) parallelism.

What's changed

Adds hybrid parallelism to the Llama3 code base using the concept of submeshes. The implementation lives mainly at the LlamaGenerator level: the MeshDevice is partitioned into submeshes, and each subset of devices runs an independent copy of the model. Within each submesh, the model remains tensor parallel.
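
For illustration, a minimal sketch of the submesh-based setup, assuming the ttnn mesh APIs that appear in the diff below (mesh_device.create_submeshes, ttnn.MeshShape); the helper and constructor names (create_data_parallel_models, build_llama_model) are hypothetical placeholders, not the code in this PR.

import ttnn

def create_data_parallel_models(mesh_device, num_devices, data_parallel, model_args):
    # Split the full mesh into `data_parallel` submeshes, each holding
    # num_devices // data_parallel devices (mirrors the diff below).
    submeshes = mesh_device.create_submeshes(
        ttnn.MeshShape(1, num_devices // data_parallel)
    )
    # One independent, tensor-parallel model replica per submesh
    # (build_llama_model is a hypothetical constructor; the PR wires this
    # up inside LlamaGenerator).
    return [build_llama_model(model_args, submesh) for submesh in submeshes]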

Checklist

@yieldthought (Contributor) left a comment:

Clean 👌

To do:

  • Add at least one CI test that will exercise DP. I suggest adding a demo to the t3k tests.
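
For reference, a hedged sketch of how such a T3K demo test might be parametrized over DP; the fixture and helper names (mesh_device, run_llama_demo) follow the style of the diff below but are placeholders, not the PR's actual test.

import pytest

# Hypothetical parametrization: run the demo once tensor-parallel only (DP=1)
# and once in hybrid mode (DP=4 -> four 2-device submeshes on an 8-device T3K).
@pytest.mark.parametrize("data_parallel", [1, 4])
def test_llama_demo_hybrid(mesh_device, data_parallel):
    run_llama_demo(mesh_device, data_parallel=data_parallel)  # placeholder helper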

@ipotkonjak-tt ipotkonjak-tt requested a review from cfjchu March 7, 2025 14:34
@ipotkonjak-tt ipotkonjak-tt self-assigned this Mar 8, 2025
Comment on lines +452 to +453
if is_ci_env and num_devices == 8 and data_parallel > 1 and not ("3.2-1B" in llama_dir or "3.1-8B" in llama_dir):
pytest.skip("CI runs only hybrid Llama3 1b and 8b on T3K")
Contributor:

What about 3B?

@ipotkonjak-tt (Contributor, Author) replied:

I wanted to avoid burdening the CI with additional tests. 1B and 8B seemed sufficient for perf regression checks, as the smallest and largest of the smaller Llama3 variants. Should we add 3B anyway?

return data_parallel, mesh_device.create_submeshes(ttnn.MeshShape(1, num_devices // data_parallel))


def allocate_kv_cache(kv_cache_shape, dtype, num_layers, mesh_device):
Contributor:

TODO (@ipotkonjak-tt and/or @skhorasganiTT) Modify KV creation in vLLM to use this function and test with DP
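
For context, a hedged sketch of what a per-submesh KV-cache allocation with this signature could look like, assuming general ttnn host-to-mesh APIs (ttnn.as_tensor, ttnn.ReplicateTensorToMesh); this is an illustration, not the PR's actual implementation. vLLM would call it once per submesh.

import torch
import ttnn

def allocate_kv_cache(kv_cache_shape, dtype, num_layers, mesh_device):
    # Allocate K and V buffers for every layer, replicated across the devices
    # of the given (sub)mesh so each data-parallel replica owns its own cache.
    kv_cache = []
    for _ in range(num_layers):
        host_kv = torch.zeros(kv_cache_shape)
        layer_kv = [
            ttnn.as_tensor(
                host_kv,
                dtype=dtype,
                layout=ttnn.TILE_LAYOUT,
                device=mesh_device,
                memory_config=ttnn.DRAM_MEMORY_CONFIG,
                mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device),
            )
            for _ in range(2)  # one tensor for K, one for V
        ]
        kv_cache.append(layer_kv)
    return kv_cache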
